# reference: https://github.com/NVlabs/AFNO-transformer
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
import torch.fft


class AFNO2D(nn.Module):
    """Adaptive Fourier Neural Operator filter over a channels-last 2-D grid.

    The input is moved to the frequency domain with an orthonormal real 2-D
    FFT over axes 1 and 2. Every mode inside the selected band goes through
    a block-diagonal complex two-layer MLP, with ReLU applied separately to
    the real and imaginary parts. Modes outside the band are zero. The
    spectrum is soft-shrunk, transformed back, and added to the input as a
    residual.

    Args:
        hidden_size: channel count ``C`` of the input; must be divisible by
            ``num_blocks``.
        num_blocks: number of independent diagonal blocks the channels are
            split into.
        sparsity_threshold: ``lambd`` for ``F.softshrink`` applied to the
            real and imaginary parts of the filtered spectrum.
        hard_thresholding_fraction: scales ``H // 2 + 1`` to choose how many
            modes of the band are processed.
        hidden_size_factor: width multiplier of the MLP hidden layer.
    """

    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1,
                 hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = hidden_size // num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        inner = self.block_size * hidden_size_factor
        # Leading axis of size 2: index 0 is the real part, index 1 the
        # imaginary part of each complex weight / bias. Draw order
        # (w1, b1, w2, b2) is kept fixed for RNG reproducibility.
        self.w1 = nn.Parameter(self.scale * torch.randn(2, num_blocks, self.block_size, inner))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, num_blocks, inner))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, num_blocks, inner, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, num_blocks, self.block_size))

    @staticmethod
    def _block_mix(re_part, im_part, weight, bias):
        """Apply a block-diagonal complex affine map; return (real, imag) pre-activations."""
        def mix(t, w):
            return torch.einsum('...bi,bio->...bo', t, w)

        # (a + ib)(w0 + i w1) + (c0 + i c1), expanded into real arithmetic.
        new_re = mix(re_part, weight[0]) - mix(im_part, weight[1]) + bias[0]
        new_im = mix(im_part, weight[0]) + mix(re_part, weight[1]) + bias[1]
        return new_re, new_im

    def forward(self, x):
        """Filter ``x`` of shape ``(B, H, W, C)`` with ``C == hidden_size``.

        The computation runs in float32; the result is cast back to the input
        dtype before the residual is added.
        """
        residual = x
        in_dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape
        freq_w = W // 2 + 1

        spec = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        spec = spec.reshape(B, H, freq_w, self.num_blocks, self.block_size)

        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)
        # Band of modes that is processed; everything outside it stays zero.
        band = (
            slice(None),
            slice(total_modes - kept_modes, total_modes + kept_modes),
            slice(None, kept_modes),
        )

        selected = spec[band]
        hid_re, hid_im = self._block_mix(selected.real, selected.imag, self.w1, self.b1)
        hid_re = F.relu(hid_re)
        hid_im = F.relu(hid_im)
        res_re, res_im = self._block_mix(hid_re, hid_im, self.w2, self.b2)

        filtered_re = torch.zeros(spec.shape, device=spec.device)
        filtered_im = torch.zeros(spec.shape, device=spec.device)
        filtered_re[band] = res_re
        filtered_im[band] = res_im

        packed = torch.stack([filtered_re, filtered_im], dim=-1)
        packed = F.softshrink(packed, lambd=self.sparsity_threshold)
        spec = torch.view_as_complex(packed).reshape(B, H, freq_w, C)

        out = torch.fft.irfft2(spec, s=(H, W), dim=(1, 2), norm="ortho")
        return out.type(in_dtype) + residual


class AFNONet(nn.Module):
    """One AFNO2D spectral filter applied to a channels-first feature map.

    ``forward`` views its input as ``(B, embed_dim, h, w)``. It moves the
    channel axis last for :class:`AFNO2D`, which unpacks ``(B, H, W, C)``,
    then moves it back. The result has shape ``(B, embed_dim, h, w)``.

    Args:
        img_size: spatial size ``(h, w)`` of the feature map.
        embed_dim: number of channels.
        num_blocks: number of diagonal blocks in the AFNO2D filter.
        sparsity_threshold: soft-shrink threshold of the AFNO2D filter.
        hard_thresholding_fraction: fraction of Fourier modes AFNO2D processes.
    """

    def __init__(
            self,
            img_size=(16, 16),
            embed_dim=8,
            num_blocks=4,
            sparsity_threshold=0.01,
            hard_thresholding_fraction=1.0,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.num_blocks = num_blocks
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        # NOTE(review): norm1, norm2 and drop_path are built but never used in
        # forward. They are kept so that existing state_dicts still load.
        self.norm1 = norm_layer(embed_dim)
        self.filter = AFNO2D(embed_dim, num_blocks, sparsity_threshold, hard_thresholding_fraction)
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(embed_dim)

        self.h = img_size[0]
        self.w = img_size[1]

    def _init_weights(self, m):
        """Initialise Linear (truncated normal, zero bias) and LayerNorm (unit/zero) modules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): neither name exists on this module; kept for API compatibility.
        return {'pos_embed', 'cls_token'}

    def forward(self, x):
        """Apply the AFNO filter to a channels-first feature map.

        Args:
            x: tensor with ``B * embed_dim * h * w`` elements, interpreted as
                ``(B, embed_dim, h, w)``. This assumes channels-first input,
                matching the layout this method returns; confirm with callers.

        Returns:
            Tensor of shape ``(B, embed_dim, h, w)``.
        """
        B = x.shape[0]
        # reshape only reinterprets memory. A permute is needed to move the
        # channel axis last; otherwise AFNO2D's FFT over dims (1, 2) would
        # mix channel values with spatial values.
        x = x.reshape(B, self.embed_dim, self.h, self.w).permute(0, 2, 3, 1)

        x = self.filter(x)

        return x.permute(0, 3, 1, 2).contiguous()



